{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "from mindspore import context\n",
    "\n",
    "# Command-line switch that selects the execution backend (CPU unless overridden).\n",
    "parser = argparse.ArgumentParser(description='MindSpore LeNet Example')\n",
    "parser.add_argument(\n",
    "    '--device_target',\n",
    "    type=str,\n",
    "    default=\"CPU\",\n",
    "    choices=['Ascend', 'GPU', 'CPU'],\n",
    ")\n",
    "\n",
    "# parse_known_args() returns (namespace, leftovers); only the namespace is kept,\n",
    "# so arguments this parser does not define are ignored instead of raising.\n",
    "args = parser.parse_known_args()[0]\n",
    "context.set_context(\n",
    "    mode=context.GRAPH_MODE,\n",
    "    device_target=args.device_target,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "\n",
    "requests.packages.urllib3.disable_warnings()\n",
    "\n",
    "def download_dataset(dataset_url, path):\n",
    "    \"\"\"Download ``dataset_url`` into the directory ``path`` unless it is already there.\n",
    "\n",
    "    The local file is named after the last component of the URL. When a file\n",
    "    with that name already exists in ``path`` the function returns immediately\n",
    "    without touching the network.\n",
    "\n",
    "    Args:\n",
    "        dataset_url (str): URL of the file to fetch.\n",
    "        path (str): Destination directory; created when missing.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If the server answers with a 4xx/5xx status.\n",
    "        requests.RequestException: On connection errors or timeouts.\n",
    "    \"\"\"\n",
    "    filename = dataset_url.split(\"/\")[-1]\n",
    "    save_path = os.path.join(path, filename)\n",
    "    if os.path.exists(save_path):\n",
    "        return\n",
    "    # exist_ok avoids the check-then-create race of a separate os.path.exists test.\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    # Stream into a temporary name first: an interrupted or failed transfer must\n",
    "    # never leave a truncated file that the early return above would later\n",
    "    # mistake for a complete download.\n",
    "    tmp_path = save_path + \".part\"\n",
    "    try:\n",
    "        # NOTE(review): verify=False disables TLS certificate validation. It is\n",
    "        # kept for backward compatibility; consider enabling verification.\n",
    "        with requests.get(dataset_url, stream=True, verify=False, timeout=60) as res:\n",
    "            # Fail loudly instead of saving an HTML error page as the dataset.\n",
    "            res.raise_for_status()\n",
    "            with open(tmp_path, \"wb\") as f:\n",
    "                for chunk in res.iter_content(chunk_size=64 * 1024):\n",
    "                    if chunk:\n",
    "                        f.write(chunk)\n",
    "        # Rename into place only after the whole body has been written.\n",
    "        os.replace(tmp_path, save_path)\n",
    "    finally:\n",
    "        # After a successful rename the temporary file no longer exists; this\n",
    "        # only removes the leftovers of a failed download.\n",
    "        if os.path.exists(tmp_path):\n",
    "            os.remove(tmp_path)\n",
    "    print(\"The {} file is downloaded and saved in the path {} after processing\".format(os.path.basename(dataset_url), path))\n",
    "\n",
    "# Destination directories for the two MNIST splits.\n",
    "train_path = \"datasets/MNIST_Data/train\"\n",
    "test_path = \"datasets/MNIST_Data/test\"\n",
    "\n",
    "# (url, destination directory) pairs, fetched in the order listed.\n",
    "for _mnist_url, _mnist_dir in (\n",
    "    (\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte\", train_path),\n",
    "    (\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte\", train_path),\n",
    "    (\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte\", test_path),\n",
    "    (\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte\", test_path),\n",
    "):\n",
    "    download_dataset(_mnist_url, _mnist_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset as ds\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "import mindspore.dataset.vision.c_transforms as CV\n",
    "from mindspore.dataset.vision import Inter\n",
    "from mindspore import dtype as mstype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset(data_path, batch_size=32, repeat_size=1,\n",
    "                   num_parallel_workers=1):\n",
    "    \"\"\"Build the MNIST input pipeline for the files under ``data_path``.\n",
    "\n",
    "    The label column is cast to int32. The image column is resized to 32x32\n",
    "    with linear interpolation, passed through two Rescale ops (1/255 scaling,\n",
    "    then the 0.1307 / 0.3081 normalisation constants) and converted from HWC\n",
    "    to CHW. The dataset is then shuffled, batched with the incomplete last\n",
    "    batch dropped, and repeated.\n",
    "\n",
    "    Args:\n",
    "        data_path (str): Directory handed to ``ds.MnistDataset``.\n",
    "        batch_size (int): Number of samples per batch.\n",
    "        repeat_size (int): Repeat count applied after batching.\n",
    "        num_parallel_workers (int): Worker count passed to each ``map`` call.\n",
    "\n",
    "    Returns:\n",
    "        The dataset object produced by the final ``repeat`` call.\n",
    "    \"\"\"\n",
    "    dataset = ds.MnistDataset(data_path)\n",
    "\n",
    "    # Arguments of the two Rescale ops: plain pixel scaling first, then the\n",
    "    # normalisation constants (0.1307 / 0.3081 are presumably the MNIST mean\n",
    "    # and standard deviation - confirm).\n",
    "    target_size = (32, 32)\n",
    "    pixel_scale, pixel_shift = 1.0 / 255.0, 0.0\n",
    "    norm_scale = 1 / 0.3081\n",
    "    norm_shift = -1 * 0.1307 / 0.3081\n",
    "\n",
    "    # Per-column transforms; the image ops run in the order listed.\n",
    "    label_op = C.TypeCast(mstype.int32)\n",
    "    image_ops = [\n",
    "        CV.Resize(target_size, interpolation=Inter.LINEAR),\n",
    "        CV.Rescale(pixel_scale, pixel_shift),\n",
    "        CV.Rescale(norm_scale, norm_shift),\n",
    "        CV.HWC2CHW(),\n",
    "    ]\n",
    "    dataset = dataset.map(operations=label_op, input_columns=\"label\", num_parallel_workers=num_parallel_workers)\n",
    "    dataset = dataset.map(operations=image_ops, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "\n",
    "    # Shuffle, batch, then repeat.\n",
    "    dataset = dataset.shuffle(buffer_size=10000)\n",
    "    dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "    return dataset.repeat(count=repeat_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore.common.initializer import Normal\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"LeNet-5 style network: two conv/ReLU/max-pool stages and three dense layers.\n",
    "\n",
    "    Args:\n",
    "        num_class (int): Size of the final dense layer's output.\n",
    "        num_channel (int): Number of channels of the input images.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super().__init__()\n",
    "        # Convolution stages: 5x5 kernels, 'valid' padding.\n",
    "        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=6, kernel_size=5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, pad_mode='valid')\n",
    "        # Dense head; 16 * 5 * 5 is the flattened conv output size, which\n",
    "        # assumes 32x32 inputs.\n",
    "        self.fc1 = nn.Dense(in_channels=16 * 5 * 5, out_channels=120, weight_init=Normal(0.02))\n",
    "        self.fc2 = nn.Dense(in_channels=120, out_channels=84, weight_init=Normal(0.02))\n",
    "        self.fc3 = nn.Dense(in_channels=84, out_channels=num_class, weight_init=Normal(0.02))\n",
    "        # Parameter-free layers reused across stages.\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        \"\"\"Forward pass; returns the output of the last dense layer (no softmax).\"\"\"\n",
    "        x = self.max_pool2d(self.relu(self.conv1(x)))\n",
    "        x = self.max_pool2d(self.relu(self.conv2(x)))\n",
    "        x = self.flatten(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "# Instantiate the network\n",
    "net = LeNet5()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function\n",
    "net_loss = nn.SoftmaxCrossEntropyWithLogits(\n",
    "    sparse=True,\n",
    "    reduction='mean',\n",
    ")\n",
    "# Optimizer over all trainable parameters of the network\n",
    "net_opt = nn.Momentum(\n",
    "    net.trainable_params(),\n",
    "    learning_rate=0.01,\n",
    "    momentum=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "# Checkpoint policy. 1875 is presumably one epoch (60000 MNIST training images\n",
    "# at batch size 32) - confirm if the batch size or dataset changes.\n",
    "config_ck = CheckpointConfig(\n",
    "    save_checkpoint_steps=1875,\n",
    "    keep_checkpoint_max=10,\n",
    ")\n",
    "# Callback that applies the policy when passed to model.train\n",
    "ckpoint = ModelCheckpoint(config=config_ck, prefix=\"checkpoint_lenet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_net(model, data_path):\n",
    "    \"\"\"Evaluate ``model`` on the dataset under ``<data_path>/test`` and print the metrics.\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing ``eval(dataset, dataset_sink_mode=...)``.\n",
    "        data_path (str): Directory that contains the ``test`` sub-directory.\n",
    "    \"\"\"\n",
    "    eval_dir = os.path.join(data_path, \"test\")\n",
    "    eval_dataset = create_dataset(eval_dir)\n",
    "    metrics = model.eval(eval_dataset, dataset_sink_mode=False)\n",
    "    print(\"{}\".format(metrics))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入模型训练需要的库\n",
    "from mindspore.nn import Accuracy\n",
    "from mindspore.train.callback import LossMonitor\n",
    "from mindspore import Model\n",
    "\n",
    "def train_net(model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):\n",
    "    \"\"\"Train ``model`` on the dataset under ``<data_path>/train``.\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing ``train(epoch, dataset, callbacks=..., dataset_sink_mode=...)``.\n",
    "        epoch_size (int): Epoch count passed to ``model.train``.\n",
    "        data_path (str): Directory that contains the ``train`` sub-directory.\n",
    "        repeat_size (int): Repeat count forwarded to ``create_dataset``.\n",
    "        ckpoint_cb: Checkpoint callback added to the training callbacks.\n",
    "        sink_mode (bool): Value forwarded as ``dataset_sink_mode``.\n",
    "    \"\"\"\n",
    "    train_dir = os.path.join(data_path, \"train\")\n",
    "    train_dataset = create_dataset(train_dir, batch_size=32, repeat_size=repeat_size)\n",
    "    # LossMonitor(125): presumably reports the loss every 125 steps - confirm.\n",
    "    training_callbacks = [ckpoint_cb, LossMonitor(125)]\n",
    "    model.train(epoch_size, train_dataset, callbacks=training_callbacks, dataset_sink_mode=sink_mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "train_epoch = 1\n",
    "mnist_path = \"./datasets/MNIST_Data\"\n",
    "dataset_size = 1  # forwarded to train_net as repeat_size\n",
    "\n",
    "# Bundle network, loss, optimizer and metric, then train and evaluate.\n",
    "model = Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={\"Accuracy\": Accuracy()})\n",
    "train_net(\n",
    "    model,\n",
    "    epoch_size=train_epoch,\n",
    "    data_path=mnist_path,\n",
    "    repeat_size=dataset_size,\n",
    "    ckpoint_cb=ckpoint,\n",
    "    sink_mode=False,\n",
    ")\n",
    "test_net(model, mnist_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import load_checkpoint, load_param_into_net\n",
    "# Load the checkpoint written by the training run of this kernel session.\n",
    "# A hard-coded file name breaks on re-runs: when checkpoints from an earlier\n",
    "# run already exist, ModelCheckpoint reportedly alters the prefix, so the fixed\n",
    "# name would load a stale file (NOTE(review): confirm for your MindSpore\n",
    "# version). The original name is kept as a fallback for the case where\n",
    "# `ckpoint` has not saved anything yet (latest_ckpt_file_name is empty).\n",
    "ckpt_file = ckpoint.latest_ckpt_file_name or \"checkpoint_lenet-1_1875.ckpt\"\n",
    "param_dict = load_checkpoint(ckpt_file)\n",
    "# Copy the loaded parameters into the network\n",
    "load_param_into_net(net, param_dict)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
